Skip to content

Conversation

ChrisRackauckas
Copy link
Member

No description provided.

@ChrisRackauckas
Copy link
Member Author

MWE of the stuff ODE solver problem:

using Flux, Zygote, ForwardDiff
import ForwardDiff: Dual

y = Float32[1.2449092, 0.26629877]
p = Float32[0.14421135, -0.006150621, 0.0393358, 0.138404, 0.23749629, 0.06463469, -0.029445898, -0.33279192, -0.094798535, -0.257304, -0.3355695, -0.1959481, 0.12938745, 0.14058144, 0.32916018, -0.23945713, -0.18813372, -0.14978944, 0.18167028, -0.22040617, -0.16580728, -0.09962158, -0.12878253, 0.24638167, -0.03310824, 0.07440266, 0.03885393, 0.27210253, -0.053823117, -0.14623246, -0.034661364, 0.049675502, 0.16398363, -0.30591217, 0.18999895, -0.26469624, 0.28702003, 0.20897748, -0.32785562, -0.100942954, -0.32169065, 0.21481845, 0.09703442, 0.30915034, 0.09057236, -0.15546058, -0.24163458, -0.13516225, -0.06676043, -0.1966813, 0.12077151, 0.056194287, -0.16526969, -0.2222915, -0.19672059, -0.034455374, -0.24578816, 0.18768719, -0.23405759, 0.046496972, -0.258523, 0.058912445, 0.042145796, -0.13487151, -0.2644665, -0.33397835, -0.29189992, 0.13996881, -0.21306355, 0.15383047, 0.15763333, -0.27050394, 0.3312636, 0.32032087, -0.24478982, -0.1096856, 0.12329024, -0.33420125, -0.1529397, 0.013263283, -0.0321317, -0.28141057, -0.058830447, -0.033951838, -0.18657157, 0.20016932, 0.1548164, 0.028861579, 0.16291597, 0.22635445, -0.090969354, 0.1766979, -0.31983075, 0.07219792, 0.23401073, -0.07207494, 0.24587327, 0.26736307, -0.23342982, 0.08657169, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.19381283, 0.26416284, 0.15891802, -0.12095787, 0.1411279, -0.027013905, 0.30863065, 0.122840405, 0.21558835, -0.15353091, 0.29037133, 0.15146789, 0.045571294, 0.31934103, -0.013038424, 0.19338936, -0.08277726, 0.23436436, -0.25519383, 0.19290958, -0.07854137, 0.25316665, 0.068068326, -0.25551397, 0.29491913, -0.17783256, -0.28973922, -0.102145046, -0.052218888, -0.14526269, -0.20289962, -0.22463948, 0.24003603, -0.22635874, 0.22355433, -0.10727884, -0.27763215, 0.12205175, -0.33481315, 0.04747853, 0.22055429, 0.017725615, -0.14218004, -0.27020591, 0.27612484, 0.050210662, 0.041809335, -0.032814298, -0.21339102, 0.22898024, 0.0755317, -0.23465283, -0.109813884, -0.18060842, -0.066495314, 0.22580191, 0.3323758, 0.023281226, 0.07484222, 0.28912178, -0.27487472, 0.121484526, -0.2651789, 0.19090225, -0.003508792, -0.25500044, 0.05072003, -0.07643754, 0.24113968, 0.12844749, 0.24001858, 0.2613778, -0.2603248, 0.08254892, -0.111656696, 0.23785193, 0.32324004, 0.1750177, -0.09340208, -0.12355742, -0.25986317, -0.27915004, 0.07588966, 0.25872853, 0.21791716, 0.2401611, -0.2407115, -0.23268251, -0.30390444, -0.3009561, -0.02586944, 0.16676147, -0.110212825, -0.17888871, 0.33321387, -0.32094577, -0.25499186, 0.25705588, 0.15148534, -0.28999805, 0.0, 0.0]
t = 1.5f0
λ = Dual{ForwardDiff.Tag{OrdinaryDiffEq.OrdinaryDiffEqTag,Float32},Float32,12}[Dual{ForwardDiff.Tag{OrdinaryDiffEq.OrdinaryDiffEqTag,Float32}}(0.09447026, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), Dual{ForwardDiff.Tag{OrdinaryDiffEq.OrdinaryDiffEqTag,Float32}}(1.4116058, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)]

model = Chain(x -> x .^ 3,
    Dense(2, 50, tanh),
    Dense(50, 2))
p, re = Flux.destructure(model)
f(u, p, t) = re(p)(u)

_dy, back = Zygote.pullback(y, p) do u, p
    vec(f(u, p, t))
end
tmp1, tmp2 = back(λ)

Found via:

using DiffEqFlux, OrdinaryDiffEq, Test

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

model = Chain(x -> x .^ 3,
    Dense(2, 50, tanh),
    Dense(50, 2))
neuralde = NeuralODE(model, tspan, Rodas5(), saveat=t, reltol=1e-7, abstol=1e-9)

function predict_n_ode()
    neuralde(u0)
end
loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())

data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function () #callback function to observe training
    display(loss_n_ode())
end

# Display the ODE with the initial parameter values.
cb()

neuralde = NeuralODE(model, tspan, Rodas5(), saveat=t, reltol=1e-7, abstol=1e-9)
ps = Flux.params(neuralde)
loss1 = loss_n_ode()

xx = Ref{Any}()

Flux.train!(loss_n_ode, ps, data, opt, cb=cb)

with the change:

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, W) where TS<:SensitivityFunction
  @unpack sensealg, f = S
  prob = getprob(S)

  isautojacvec = get_jacvec(sensealg)
  if inplace_sensitivity(S)
    if W===nothing
      _dy, back = Zygote.pullback(y, p) do u, p
        out_ = Zygote.Buffer(similar(u))
        f(out_, u, p, t)
        vec(copy(out_))
      end
    else
      _dy, back = Zygote.pullback(y, p) do u, p
        out_ = Zygote.Buffer(similar(u))
        f(out_, u, p, t, W)
        vec(copy(out_))
      end
    end
    tmp1,tmp2 = back(λ)
    dλ[:] .= vec(tmp1)
    dgrad !== nothing && tmp2 !== nothing && (dgrad[:] .= vec(tmp2))
    dy !== nothing && (dy[:] .= vec(_dy))
  else
    if W===nothing
      _dy, back = Zygote.pullback(y, p) do u, p
        vec(f(u, p, t))
      end
    else
      _dy, back = Zygote.pullback(y, p) do u, p
        vec(f(u, p, t, W))
      end
    end
    Main.xx[] = y,p,t,λ
    tmp1, tmp2 = back(λ)
    tmp1 !== nothing && (dλ[:] .= vec(tmp1))
    dy !== nothing && (dy[:] .= vec(_dy))
    dgrad !== nothing && tmp2 !== nothing && (dgrad[:] .= vec(tmp2))
  end
  return
end

ProjectManifest.zip

@ChrisRackauckas
Copy link
Member Author

using ForwardDiff, Zygote, Flux
using ForwardDiff: Dual
y = Float32[0.8564646, 0.21083355]
p = Float32[-0.2548858, -0.264061, 0.06902494, -0.23288882, -0.13166176, 0.25982612, -0.26543534, -0.29349443, 0.31963557, 0.21243489, -0.2755482, -0.04317024, 0.2678376, -0.32618907, -0.11215708, -0.20082082, -0.075056225, -0.3250112, -0.20113565, -0.2580761, 0.03797583, -0.1354496, 0.18161258, 0.3180589, 0.283674, 0.05116003, -0.07082515, 0.12914972, 0.09830813, 0.29125124, 0.32423735, 0.045021717, 0.09604585, 0.007445923, 0.12431481, 0.063025564, 0.30161184, 0.23123802, 0.30304855, -0.18616274, 0.06983177, 0.13229537, 0.26679033, 0.29119095, 0.2044387, -0.1310391, 0.06418764, -0.05145624, 0.28958446, 0.08143681, -0.26594874, 0.258198, -0.16387275, 0.23627394, -0.0025739619, 0.12877232, 0.28468516, 0.14945742, -0.09824067, 0.22391124, 0.2722607, 0.034997866, 0.021131594, -0.058169674, -0.20168333, 0.3310362, 0.29977754, 0.27228144, 0.088294245, 0.17472656, 0.030819716, 0.27218765, 0.042448767, 0.25967237, 0.18181679, 0.2810931, -0.16689181, 0.17927635, 0.32586476, -0.25481033, 0.009913104, 0.20943141, -0.13506782, -0.30059853, -0.084571846, -0.31261674, 0.11608189, 0.084546946, -0.21448077, -0.19288287, -0.22511461, 0.27675447, 0.26279518, 0.061226156, -0.2828123, -0.1394083, -0.16996919, 0.2784961, -0.0039018209, -0.1362619, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.048874624, 0.11889865, 0.01040518, 0.12694769, 0.32327807, 0.13581258, 0.10043003, -0.12258695, 0.32029858, -0.05385616, -0.28262973, -0.29426816, 0.11472986, 0.014853499, -0.055616893, -0.24432188, -0.23522359, -0.07780609, 0.16605335, 0.29451388, -0.32305816, 0.03262463, -0.28862894, 0.054972157, 0.2411704, 0.31518432, 0.2221482, -0.12357236, 0.25466782, 0.03921116, -0.087710164, 0.1594814, -0.33685195, -0.13411506, 0.04239876, 0.260748, 0.15104404, 0.24697773, -0.06698533, -0.039195247, 0.29528958, -0.19330974, -0.32768622, 0.07959501, -0.11285911, -0.031941384, -0.108291335, -0.24830729, -0.08987814, -0.04234308, 0.255426, 0.3337179, 0.18690939, -0.32503495, -0.06603645, -0.17818044, 0.10007081, -0.22569874, 0.030490262, -0.014429291, 0.13864784, 0.100892544, -0.28683808, 0.05345175, -0.12727126, 0.31637886, 0.27381366, 0.026415939, 0.20263642, 0.33452004, -0.3351626, 0.0063842274, -0.26546854, -0.24439275, -0.19636214, 0.3032137, 0.13219267, 0.20853092, -0.05988348, -0.30968776, -0.1278926, 0.33035672, -0.32249796, 0.14322737, -0.29625347, -0.17458698, -0.0010983021, 0.14215776, -0.07308902, -0.19241002, 0.1702171, 0.32165667, 0.27042934, 0.068846, 0.19114906, 0.06528145, -0.31603774, 0.049985882, -0.05847536, 0.04034526, 0.0, 0.0]
t = 1.5f0
λ = ForwardDiff.Dual{ForwardDiff.Tag{Nothing,Float32},Float32,12}[Dual{ForwardDiff.Tag{Nothing,Float32}}(0.87135935, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), Dual{ForwardDiff.Tag{Nothing,Float32}}(1.5225363, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)]

model = Chain(x -> x .^ 3,
    Dense(2, 50, tanh),
    Dense(50, 2))

p,re = Flux.destructure(model)
f(u, p, t) = re(p)(u)
_dy, back = Zygote.pullback(y, p) do u, p
    vec(f(u, p, t))
end
tmp1, tmp2 = back(λ)

@ChrisRackauckas
Copy link
Member Author

Should work once FluxML/Optimisers.jl#65

@ChrisRackauckas
Copy link
Member Author

SciML/NeuralPDE.jl#508 and SciML/DeepEquilibriumNetworks.jl#44 are dependent on this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant